﻿using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Security.Cryptography;
using System.Text;

namespace Client.Algorithm
{
        /// <summary>
        /// AES-256-CBC encryption combined with an HMAC-SHA256 authentication tag
        /// (encrypt-then-MAC). Both keys are derived from a single master key.
        /// </summary>
        /// <remarks>
        /// Message layout produced by <see cref="Encrypt(byte[])"/> and expected by
        /// <see cref="Decrypt(byte[])"/>:
        /// <code>
        /// ----------------------------------------
        /// |     HMAC     |   IV   |  CIPHERTEXT  |
        /// ----------------------------------------
        ///     32 bytes    16 bytes
        /// </code>
        /// The HMAC covers the IV and the ciphertext.
        /// </remarks>
        public class Aes256
        {
            private const int KeyLength = 32;
            private const int AuthKeyLength = 64;
            private const int IvLength = 16;
            private const int HmacSha256Length = 32;

            // Offset of the ciphertext inside a message (after HMAC and IV).
            private const int CipherTextOffset = HmacSha256Length + IvLength;

            private readonly byte[] _key;
            private readonly byte[] _authKey;

            // Fixed salt shared by every instance; used only for key derivation.
            private static readonly byte[] Salt = Encoding.ASCII.GetBytes("DcRatByqwqdanchun");

            /// <summary>
            /// Derives the encryption key and the authentication key from
            /// <paramref name="masterKey"/> using <see cref="Rfc2898DeriveBytes"/>
            /// with 50000 iterations.
            /// </summary>
            /// <param name="masterKey">Secret the keys are derived from.</param>
            /// <exception cref="ArgumentException"><paramref name="masterKey"/> is null or empty.</exception>
            public Aes256(string masterKey)
            {
                if (string.IsNullOrEmpty(masterKey))
                {
                    throw new ArgumentException($"{nameof(masterKey)} can not be null or empty.", nameof(masterKey));
                }

                using (Rfc2898DeriveBytes derive = new Rfc2898DeriveBytes(masterKey, Salt, 50000))
                {
                    // Order matters: the first 32 derived bytes are the AES key,
                    // the following 64 bytes are the HMAC key.
                    _key = derive.GetBytes(KeyLength);
                    _authKey = derive.GetBytes(AuthKeyLength);
                }
            }

            /// <summary>
            /// Encrypts a string (encoded as UTF-8) and returns the message as Base64.
            /// </summary>
            /// <param name="input">Plaintext to encrypt.</param>
            /// <returns>Base64 encoding of the HMAC, IV and ciphertext.</returns>
            public string Encrypt(string input)
            {
                return Convert.ToBase64String(Encrypt(Encoding.UTF8.GetBytes(input)));
            }

            /// <summary>
            /// Encrypts <paramref name="input"/> with a freshly generated IV and
            /// prepends the HMAC-SHA256 of the IV and ciphertext.
            /// </summary>
            /// <param name="input">Plaintext bytes; may be empty.</param>
            /// <returns>HMAC (32 bytes), IV (16 bytes), then the ciphertext.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
            public byte[] Encrypt(byte[] input)
            {
                if (input == null)
                {
                    throw new ArgumentNullException(nameof(input), $"{nameof(input)} can not be null.");
                }

                using (AesCryptoServiceProvider aesProvider = CreateAes())
                {
                    aesProvider.GenerateIV();
                    byte[] iv = aesProvider.IV;

                    byte[] cipherText;
                    using (ICryptoTransform encryptor = aesProvider.CreateEncryptor())
                    {
                        cipherText = encryptor.TransformFinalBlock(input, 0, input.Length);
                    }

                    // Assemble the message in place; the first 32 bytes stay reserved
                    // for the HMAC until it has been computed over the rest.
                    byte[] output = new byte[CipherTextOffset + cipherText.Length];
                    Buffer.BlockCopy(iv, 0, output, HmacSha256Length, IvLength);
                    Buffer.BlockCopy(cipherText, 0, output, CipherTextOffset, cipherText.Length);

                    using (var hmac = new HMACSHA256(_authKey))
                    {
                        byte[] hash = hmac.ComputeHash(output, HmacSha256Length, output.Length - HmacSha256Length);
                        Buffer.BlockCopy(hash, 0, output, 0, hash.Length);
                    }

                    return output;
                }
            }

            /// <summary>
            /// Decrypts a Base64 message produced by <see cref="Encrypt(string)"/>
            /// and decodes the plaintext as UTF-8.
            /// </summary>
            /// <param name="input">Base64 encoding of the HMAC, IV and ciphertext.</param>
            /// <returns>The decrypted text.</returns>
            public string Decrypt(string input)
            {
                return Encoding.UTF8.GetString(Decrypt(Convert.FromBase64String(input)));
            }

            /// <summary>
            /// Verifies the HMAC of <paramref name="input"/> and, if it matches,
            /// decrypts the ciphertext that follows the IV.
            /// </summary>
            /// <param name="input">HMAC (32 bytes), IV (16 bytes), then the ciphertext.</param>
            /// <returns>The decrypted plaintext bytes.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
            /// <exception cref="ArgumentException"><paramref name="input"/> is shorter than the HMAC.</exception>
            /// <exception cref="CryptographicException">
            /// The input is too short to hold an IV, the HMAC does not match, or the
            /// ciphertext cannot be decrypted.
            /// </exception>
            public byte[] Decrypt(byte[] input)
            {
                if (input == null)
                {
                    throw new ArgumentNullException(nameof(input), $"{nameof(input)} can not be null.");
                }

                if (input.Length < HmacSha256Length)
                {
                    throw new ArgumentException($"{nameof(input)} is too short to contain a message authentication code (MAC).", nameof(input));
                }

                if (input.Length < CipherTextOffset)
                {
                    throw new CryptographicException("Input is too short to contain an initialization vector (IV).");
                }

                // Authenticate before touching the cipher: the HMAC covers IV + ciphertext.
                using (var hmac = new HMACSHA256(_authKey))
                {
                    byte[] expectedHash = hmac.ComputeHash(input, HmacSha256Length, input.Length - HmacSha256Length);
                    byte[] receivedHash = new byte[HmacSha256Length];
                    Buffer.BlockCopy(input, 0, receivedHash, 0, HmacSha256Length);

                    if (!AreEqual(expectedHash, receivedHash))
                    {
                        throw new CryptographicException("Invalid message authentication code (MAC).");
                    }
                }

                byte[] iv = new byte[IvLength];
                Buffer.BlockCopy(input, HmacSha256Length, iv, 0, IvLength);

                using (AesCryptoServiceProvider aesProvider = CreateAes())
                {
                    aesProvider.IV = iv;

                    using (ICryptoTransform decryptor = aesProvider.CreateDecryptor())
                    {
                        // Decrypt the whole ciphertext in one call. A single Stream.Read on a
                        // CryptoStream may legally return fewer bytes than requested, which
                        // would truncate the plaintext.
                        return decryptor.TransformFinalBlock(input, CipherTextOffset, input.Length - CipherTextOffset);
                    }
                }
            }

            /// <summary>
            /// Creates an AES provider configured for 256-bit keys, 128-bit blocks,
            /// CBC mode and PKCS7 padding, loaded with the derived encryption key.
            /// The caller owns (and must dispose) the returned instance.
            /// </summary>
            private AesCryptoServiceProvider CreateAes()
            {
                return new AesCryptoServiceProvider
                {
                    KeySize = 256,
                    BlockSize = 128,
                    Mode = CipherMode.CBC,
                    Padding = PaddingMode.PKCS7,
                    Key = _key
                };
            }

            /// <summary>
            /// Compares two byte arrays for equality.
            /// </summary>
            /// <param name="a1">Byte array to compare</param>
            /// <param name="a2">Byte array to compare</param>
            /// <returns>True if both arrays are non-null, equally long and equal byte for byte; else false</returns>
            /// <remarks>
            /// Every byte is examined even after a mismatch has been found, and the
            /// differences are accumulated without branching, so the loop does not
            /// exit early at the first differing position.
            /// </remarks>
            [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)]
            private static bool AreEqual(byte[] a1, byte[] a2)
            {
                if (a1 == null || a2 == null || a1.Length != a2.Length)
                {
                    return false;
                }

                int difference = 0;
                for (int i = 0; i < a1.Length; ++i)
                {
                    difference |= a1[i] ^ a2[i];
                }

                return difference == 0;
            }
        }
    }